23  Регрессионные модели с tidymodels

23.1 Регрессионные алгоритмы

В машинном обучении проблемы, связанные с количественным откликом, называют проблемами регрессии, а проблемы, связанные с качественным откликом, проблемами классификации. В прошлом уроке мы познакомились с простой и множественной регрессией, но регрессионных алгоритмов великое множество. Вот лишь некоторые из них:

  1. полиномиальная регрессия: расширение линейной регрессии, позволяющее учитывать нелинейные зависимости.

  2. логистическая регрессия: используется для прогнозирования категориальных (бинарных) откликов.

  3. регрессия на опорных векторах (SVM): ищет гиперплоскость, позволяющую минимизировать ошибку в многомерном пространстве.

  4. деревья регрессии: строят иерархическую древовидную модель, последовательно разбивая данные на подгруппы.

  5. случайный лес: комбинирует предсказания множества деревьев для повышения точности и устойчивости.

Кроме того, существуют методы регуляризации линейных моделей, позволяющие существенно улучшить их качество на данных большой размерности (т.е. с большим количеством предкторов). К таким алгоритмам относятся гребневая регрессия и метод лассо. О них мы поговорим в одном из следующих уроков.

О математической стороне дела см. Г. Джеймс, Д. Уиттон, Т. Хасти, Р. Тибришани (2017). В этом уроке мы научимся работать с различными регрессионными алгоритмами, используя библиотеку tidymodels.

23.2 Библиотека tidymodels

Библиотека tidymodels позволяет обучать модели и оценивать их эффективность с использованием принципов опрятных данных. Она представляет собой набор пакетов R, которые разработаны для работы с машинным обучением и являются частью более широкой экосистемы tidyverse.

Вот некоторые из ключевых пакетов, входящих в состав tidymodels:

  1. parsnip - универсальный интерфейс для различных моделей машинного обучения, который упрощает переключение между разными типами моделей;

  2. recipes - фреймворк для создания и управления “рецептами” предварительной обработки данных перед тренировкой модели;

  3. rsample - инструменты для разделения данных на обучающую и тестовую выборки, а также для кросс-валидации;

  4. tune - функции для оптимизации гиперпараметров моделей машинного обучения;

  5. yardstick - инструменты для оценки производительности моделей;

  6. workflow позволяет объединить различные компоненты модели в единый объект: препроцессинг данных, модель машинного обучения, настройку гиперпараметров.

Мы также будем использовать пакет textrecipes, который представляет собой аналог recipes для текстовых данных.

library(tidyverse)
library(tidymodels)
library(textrecipes)

23.3 Данные

Датасет для этого урока хранит данные о названиях, рейтингах, жанре, цене и числе отзывов на некоторые книги с Amazon. Мы попробуем построить регресионную модель, которая будет предсказывать цену книги.

books  <- readxl::read_xlsx("../files/AmazonBooks.xlsx")
books

Данные не очень опрятны, и прежде всего их надо тайдифицировать.

colnames(books) <- tolower(colnames(books))
books <- books |> 
  rename(rating = `user rating`)

На графике ниже видно, что сильной корреляции между количественными переменными не прослеживается, так что задача перед нами стоит незаурядная. Посмотрим, что можно сделать в такой ситуации.

books |> 
  select_if(is.numeric) |> 
  cor() |> 
  corrplot::corrplot(method = "ellipse")

Мы видим, что количественные предикторы объясняют лишь ничтожную долю дисперсии (чуть более информативен жанр).

summary(lm(price ~ reviews + year + rating + genre, data  = books))

Call:
lm(formula = price ~ reviews + year + rating + genre, data = books)

Residuals:
    Min      1Q  Median      3Q     Max 
-16.472  -5.050  -1.841   2.307  89.686 

Coefficients:
                   Estimate Std. Error t value Pr(>|t|)    
(Intercept)       8.987e+02  2.734e+02   3.287  0.00107 ** 
reviews           7.779e-07  3.181e-05   0.024  0.98050    
year             -4.324e-01  1.370e-01  -3.156  0.00168 ** 
rating           -3.655e+00  1.933e+00  -1.891  0.05909 .  
genreNon Fiction  3.920e+00  8.669e-01   4.522 7.41e-06 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 10.16 on 595 degrees of freedom
Multiple R-squared:  0.06903,   Adjusted R-squared:  0.06277 
F-statistic: 11.03 on 4 and 595 DF,  p-value: 1.235e-08

Посмотрим, можно ли как-то улучшить этот результат. Но сначала оценим визуально связь между ценой, с одной стороны, и годом и жанром, с другой.

g1 <- books |> 
  ggplot(aes(year, price, color = genre, group = genre)) + 
  geom_jitter(show.legend = FALSE, alpha = 0.7) + 
  geom_smooth(method = "lm", se = FALSE) +
  theme_minimal()

g2 <- books |> 
  ggplot(aes(genre, price, color = genre)) + 
  geom_boxplot() + 
  theme_minimal()

gridExtra::grid.arrange(g1, g2, nrow = 1)

23.4 Обучающая и контрольная выборка

Вы уже знаете, при обучении модели мы стремимся к минимизации среднеквадратичной ошибки (MSE), однако в большинстве случаев нас интересует не то, как метод работает на обучающих данных, а то, как он покажет себя на контрольных данных. Чтобы избежать переобучения, очень важно в самом начале разделить доступные наблюдения на две группы.

books_split <- books |> 
  initial_split()

books_train <- training(books_split)
books_test <- testing(books_split)

23.5 Определение модели

Определение модели включает следующие шаги:

  • указывается тип модели на основе ее математической структуры (например, линейная регрессия, случайный лес, KNN и т. д.);

  • указывается механизм для подгонки модели – чаще всего это программный пакет, который должен быть использован, например glmnet. Это самостоятельные модели, и parsnip обеспечивает согласованные интерфейсы, используя их в качестве движков для моделирования.

  • при необходимости объявляется режим модели. Режим отражает тип прогнозируемого результата. Для числовых результатов режимом является регрессия, для качественных - классификация. Если алгоритм модели может работать только с одним типом результатов прогнозирования, например, линейной регрессией, режим уже задан.

23.6 Регрессия на опорных векторах

Support Vector Regression — это метод машинного обучения, основанный на идеях метода опорных векторов (SVM), но адаптированный к задаче регрессии, а не классификации (о чем см. следующий урок).

Вместо поиска разделяющей гиперплоскости между классами (как в классификации), SVR старается найти функцию, которая:

  • игнорирует небольшие отклонения внутри некоторого допустимого порога ε (эпсилон),
  • акцентирует внимание на точках, которые лежат вне этой “трубы”, — это и есть опорные векторы.

В этом заключается отличие от обычной регрессии, которая старается проложить прямую, наиболее близкую ко всем точкам и “наказывает” любое отклонение.

SVR тоже строит линию (или кривую в случае нелинейного ядра), но с другим подходом. Она “довольна”, если предсказание находится в пределах допустимой ошибки ε (эпсилон) от настоящего значения.

SVR концентрируется только на тех точках, что выходят за эту “зону безразличия” или лежат на ее границе — они называются опорными векторами. Именно они определяют форму и положение модели. Остальные точки (в пределах ε) никак не влияют на модель.

23.7 SVR в tidymodels

Функция translate() позволяет понять, как parsnip переводит пользовательский код на язык пакета.

svm_spec <- svm_linear() |>
  set_engine("LiblineaR") |> 
  set_mode("regression")

svm_spec |> 
  translate()
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

Model fit template:
LiblineaR::LiblineaR(x = missing_arg(), y = missing_arg(), type = 11, 
    svr_eps = 0.1)

Пока это просто спецификация модели без данных и без формулы. Добавим ее к воркфлоу.

svm_wflow <- workflow() |> 
  add_model(svm_spec)

svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: None
Model: svm_linear()

── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

23.8 Дизайн переменных

Теперь нам нужен препроцессор. За него отвечает пакет recipes. Если вы не уверены, какие шаги необходимы на этом этапе, можно заглянуть в шпаргалку. В случае с линейной регрессией это может быть логарифмическая трансформация, нормализация, отсев переменных с нулевой дисперсией (zero variance), добавление (impute) недостающих значений или удаление переменных, которые коррелируют с другими переменными.

Вот так выглядит наш первый рецепт. Обратите внимание, что формула записывается так же, как мы это делали ранее внутри функции lm().

books_rec <- recipe(price ~ year + genre + name, 
                    data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name)  |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 

При желании можно посмотреть на результат предобработки.

prep(books_rec, books_train) |> 
  bake(new_data = NULL)

Добавляем препроцессор в воркфлоу.

svm_wflow <- svm_wflow |> 
  add_recipe(books_rec)

svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

23.9 Подгонка модели

Подгоним модель на обучающих данных.

svm_fit <- svm_wflow |>
  fit(data = books_train)

Пакет broom позволяет тайдифицировать модель. Посмотрим на слова, которые приводят к “удорожанию” книг. Видно, что в начале списка – слова, связанные с научными публикациями, что не лишено смысла.

svm_fit |> 
  tidy() |> 
  arrange(-estimate)

Оценим модель на контрольных данных.

pred_data <- tibble(truth = books_test$price,
                    estimate = predict(svm_fit, books_test)$.pred)

books_metrics <- metric_set(rmse, rsq, mae)

books_metrics(pred_data, truth = truth,  estimate = estimate)

23.10 Повторные выборки

Чтобы не распечатывать каждый раз тестовые данные (в идеале мы их используем один, максимум два раза!), задействуется ряд методов, позволяющих оценить ошибку путем исключения части обучающих наблюдений из процесса подгонки модели и последующего применения этой модели к исключенным наблюдениям.

В пакете rsample из библиотеки tidymodels реализованы, среди прочего, следующие методы повторных выборок для оценки производительности моделей машинного обучения:

  1. Метод проверочной выборки – набор наблюдений делится на обучающую и проверочную, или удержанную, выборку (validation set): для этого используется initial_validation_split().

  2. K-кратная перекрестная проверка – наблюдения разбиваются на k групп примерно одинакового размера, первый блок служит в качестве проверочной выборки, а модель подгоняется по остальным k-1 блокам; процедура повторяется k раз: функция vfold_cv().

  3. Перекрестная проверка Монте-Карло – в отличие от предыдущего метода, создается множество случайных разбиений данных на обучающую и тестовую выборки: функция mc_cv().

  4. Бутстреп – отбор наблюдений выполняется с возвращением, т.е. одно и то же наблюдение может встречаться несколько раз: функция bootstraps().

  5. Перекрестная проверка по отдельным наблюдениям (leave-one-out сross-validation): одно наблюдение используется в качестве контрольного, а остальные составляют обучающую выборку; модель подгоняется по n-1 наблюдениям, что повторяется n раз: функция loo_cv().

Эти методы повторных выборок позволяют получить надежные оценки производительности моделей машинного обучения, избегая переобучения и обеспечивая репрезентативность тестовых выборок.

set.seed(05102024)
books_folds <- vfold_cv(books_train, v = 10) 

set.seed(05102024)
svm_rs <- fit_resamples(
  svm_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to 1000, but only 999 was available and selected.
There were issues with some computations   A: x1
→ B | warning: max_tokens was set to 1000, but only 967 was available and selected.
There were issues with some computations   A: x1
→ C | warning: max_tokens was set to 1000, but only 997 was available and selected.
There were issues with some computations   A: x1
→ D | warning: max_tokens was set to 1000, but only 991 was available and selected.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1

Теперь соберем метрики и убедимся, что предыдущая оценка на контрольных данных была слишком оптимистичной. Однако результат не так уж плох: во всяком случае мы смогли добиться заметного улучшения по сравнению с нулевой моделью.

collect_metrics(svm_rs)
svm_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") + 
  theme_minimal() +
  coord_cartesian(xlim = c(0,50), ylim = c(0,50))

23.11 Нулевая модель

Кстати, проверим, какой результат даст нулевая модель.

null_reg <- null_model() |> 
  set_engine("parsnip") |> 
  set_mode("regression")

null_wflow <- workflow() |> 
    add_model(null_reg) |> 
    add_recipe(books_rec)

null_rs <- fit_resamples(
  null_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
  )
→ A | warning: A correlation computation is required, but `estimate` is constant and has 0
               standard deviation, resulting in a divide by 0 error. `NA` will be returned.
→ B | warning: max_tokens was set to 1000, but only 999 was available and selected.
→ C | warning: max_tokens was set to 1000, but only 967 was available and selected.
→ D | warning: max_tokens was set to 1000, but only 997 was available and selected.
→ E | warning: max_tokens was set to 1000, but only 991 was available and selected.
collect_metrics(null_rs)

\(R^2\) в таком случае должен быть NaN.

23.12 Деревья решений: понятия

Деревья решений применяются как для задача регрессии, так и для задач классификации.

Регрессионные деревья строят последовательное разбиение пространства признаков таким образом, чтобы минимизировать среднеквадратичную ошибку (MSE) в каждом из подмножеств.

Для этого данные делятся на группы, в которых отклик (целевое значение) как можно более “однороден”. Каждое разбиение осуществляется на основе признаков (факторов), а в листьях дерева находятся средние значения отклика для соответствующей подгруппы. Вот так, например, может выглядеть предсказание расхода топлива для автомобиля (на основе датасета mtcars).

Деревья легко показать графически, их легко интерпретировать, они хорошо справляются с категориальными предикторами (без создания dummy variables). Они особенно хорошо подходят для тех случаев, когда между откликом и предикторами существует нелинейная и сложная зависимость.

Но деревья страдают от высокой дисперсии, т.е. если мы случайным образом разобьем обучающие данные на две части и построим дерево решений на основе каждой из них, полученные результаты могут оказаться довольно разными.

Чтобы с этим справиться, используют три основных метода: бэггинг, случайный лес и бустинг.

23.13 Бэггинг, случайный лес, бустинг

  1. Бэггинг — это метод построения ансамбля моделей путем:
  • повторного случайного выбора подвыборок из обучающего набора данных (бутстрэп);
  • обучения на каждой из этих подвыборок дерева решений;
  • объединения (агрегации) результатов предсказаний этих моделей (для регрессии – усреднение предсказаний; для классификации: голосование).

Хотя бэггинг может улучшить предсказания многих методов, он особено полезен для деревьев решений.

  1. Случайный лес – это частный случай бэггинга. Каждое дерево обучается на случайной выборке с возвращением (бутстрэп), но при построении дерева выбираются не все признаки, а случайное подмножество признаков. Это снижает корреляцию между деревьями и повышает качество ансамбля.

  2. Бустинг работает похожим образом, но деревья строятся последовательно: каждое дерево выращивается с использованием информации по ранее выращенным деревьям. Бустинг не задействует бутстрэп, деревья обучаются на всем наборе данных. Из-за того, что деревья обучаются последовательно, его сложнее запараллелить.

Случайный лес и бустинг плохо поддаются интерпретации.

23.14 Случайный лес в tidymodels

Уточним, какие движки доступны для случайных лесов.

show_engines("rand_forest")

Создадим спецификацию модели. Деревья используются как в задачах классификации, так и в задачах регрессии, поэтому задействуем функцию set_mode().

rf_spec <- rand_forest(trees = 1000) |> 
  set_engine("ranger") |> 
  set_mode("regression")
rf_wflow <- workflow() |> 
  add_model(rf_spec) |> 
  add_recipe(books_rec)

rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (regression)

Main Arguments:
  trees = 1000

Computational engine: ranger 

Обучение займет чуть больше времени.

rf_rs <- fit_resamples(
  rf_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to 1000, but only 999 was available and selected.
There were issues with some computations   A: x1
→ B | warning: max_tokens was set to 1000, but only 967 was available and selected.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1
→ C | warning: max_tokens was set to 1000, but only 997 was available and selected.
There were issues with some computations   A: x1   B: x1
There were issues with some computations   A: x1   B: x1   C: x1
→ D | warning: max_tokens was set to 1000, but only 991 was available and selected.
There were issues with some computations   A: x1   B: x1   C: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1

Мы видим, что среднеквадратическая ошибка уменьшилась, а доля объясненной дисперсии выросла.

collect_metrics(rf_rs)

Тем не менее на графике можно заметить нечто странное: наша модель систематически переоценивает низкие значения и недооценивает высокие. Это связано с тем, что случайные леса не очень подходят для работы с разреженными данными (Hvitfeldt и Silge 2022).

rf_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") +
  theme_minimal() +
  coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))

23.15 Градиентные бустинговые деревья

Также попробуем построить регрессию с использованием градиентных бустинговых деревьев. В 2023 г. эта техника показала хорошие результаты в эксперименте по датировке греческих документальных папирусов.

xgb_spec <- 
  boost_tree(mtry = 50, trees = 1000)  |> 
  set_engine("xgboost")  |> 
  set_mode("regression")
xgb_wflow <- workflow() |> 
  add_model(xgb_spec) |> 
  add_recipe(books_rec)

xgb_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (regression)

Main Arguments:
  mtry = 50
  trees = 1000

Computational engine: xgboost 

Проводим перекрестную проверку.

xgb_rs <- fit_resamples(
  xgb_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to 1000, but only 999 was available and selected.
→ B | warning: max_tokens was set to 1000, but only 967 was available and selected.
There were issues with some computations   A: x1   B: x1
→ C | warning: max_tokens was set to 1000, but only 997 was available and selected.
There were issues with some computations   A: x1   B: x1
There were issues with some computations   A: x1   B: x1   C: x1
→ D | warning: max_tokens was set to 1000, but only 991 was available and selected.
There were issues with some computations   A: x1   B: x1   C: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1
collect_metrics(xgb_rs)

Метрики неплохие! Но если взглянуть на остатки, можно увидеть что-то вроде буквы S.

rf_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") +
  theme_minimal() +
  coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))

23.16 Удаление стопслов

Изменим рецепт приготовления данных.

stopwords_rec <- function(stopwords_name) {
  recipe(price ~ year + genre + name, data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name)  |> 
  step_stopwords(name, stopword_source = stopwords_name) |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 
}

Создадим воркфлоу.

svm_wflow <- workflow() |> 
  add_model(svm_spec)

И снова проведем перекрестную проверку, на этот раз с разными списками стоп-слов. На этом шаге команда вернет предупреждения о том, что число слов меньше 1000, это нормально, т.к. после удаления стопслов токенов стало меньше.

set.seed(123)
snowball_rs <- fit_resamples(
  svm_wflow |>  add_recipe(stopwords_rec("snowball")),
  books_folds
)

set.seed(234)
smart_rs <- fit_resamples(
  svm_wflow |> add_recipe(stopwords_rec("smart")),
  books_folds
)

set.seed(345)
stopwords_iso_rs <- fit_resamples(
  svm_wflow |> add_recipe(stopwords_rec("stopwords-iso")),
  books_folds
)
collect_metrics(smart_rs)
collect_metrics(snowball_rs)
collect_metrics((stopwords_iso_rs))

В нашем случае удаление стоп-слов положительного эффекта не имело.

word_counts <- tibble(name = c("snowball", "smart", "stopwords-iso")) %>%
  mutate(words = map_int(name, ~length(stopwords::stopwords(source = .))))

list(snowball = snowball_rs,
     smart = smart_rs,
     `stopwords-iso` = stopwords_iso_rs)  |> 
  map_dfr(show_best, metric = "rmse", .id = "name")  |> 
  left_join(word_counts, by = "name")  |> 
  mutate(name = paste0(name, " (", words, " words)"),
         name = fct_reorder(name, words))  |> 
  ggplot(aes(name, mean, color = name)) +
  geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), alpha = 0.6) +
  geom_point(size = 3, alpha = 0.8) +
  theme(legend.position = "none") + 
  theme_minimal()

23.17 Настройки числа n-grams

ngram_rec <- function(ngram_options) {
  recipe(price ~ year + genre + name, data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name, token = "ngrams", options = ngram_options)  |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 
}
fit_ngram <- function(ngram_options) {
  fit_resamples(
    svm_wflow  |> 
    add_recipe(ngram_rec(ngram_options)),
    books_folds
  )
}
set.seed(123)
unigram_rs <- fit_ngram(list(n = 1))
→ A | warning: max_tokens was set to 1000, but only 999 was available and selected.
→ B | warning: max_tokens was set to 1000, but only 967 was available and selected.
→ C | warning: max_tokens was set to 1000, but only 997 was available and selected.
There were issues with some computations   A: x1   B: x1   C: x1
→ D | warning: max_tokens was set to 1000, but only 991 was available and selected.
There were issues with some computations   A: x1   B: x1   C: x1
There were issues with some computations   A: x1   B: x1   C: x1   D: x1
set.seed(234)
bigram_rs <- fit_ngram(list(n = 2, n_min = 1))

set.seed(345)
trigram_rs <- fit_ngram(list(n = 3, n_min = 1))
collect_metrics(unigram_rs)
collect_metrics(bigram_rs)
collect_metrics(trigram_rs)

Таким образом, униграмы дают лучший результат:

list(`1` = unigram_rs,
     `1 and 2` = bigram_rs,
     `1, 2, and 3` = trigram_rs) |> 
  map_dfr(collect_metrics, .id = "name")  |> 
  filter(.metric == "rmse")  |> 
  ggplot(aes(name, mean, color = name)) +
  geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), 
                alpha = 0.6) +
  geom_point(size = 3, alpha = 0.8) +
  theme(legend.position = "none") +
  labs(
    y = "RMSE"
  ) + 
  theme_minimal()

23.18 Лучшая модель и оценка

svm_fit <- svm_wflow |>
  add_recipe(books_rec) |> 
  fit(data = books_test)
Warning: max_tokens was set to 1000, but only 595 was available and selected.
svm_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
$TypeDetail
[1] "L2-regularized L2-loss support vector regression primal (L2R_L2LOSS_SVR)"

$Type
[1] 11

$W
          year genre_Non.Fiction tfidf_name_1 tfidf_name_10 tfidf_name_100
[1,] -1.086172          4.675565    0.2648891    -0.8700595      -3.515508
     tfidf_name_14 tfidf_name_16 tfidf_name_1936 tfidf_name_2 tfidf_name_2.0
[1,]    0.09937253     0.2080836      -0.4406231   -0.2893487      0.5316141
     tfidf_name_3 tfidf_name_4 tfidf_name_5 tfidf_name_6 tfidf_name_6th
[1,]     2.622101    -1.745463    -1.768856    -1.696583       11.55352
     tfidf_name_7 tfidf_name_8 tfidf_name_a tfidf_name_about tfidf_name_absurd
[1,]    0.5523216    -1.386538    -2.800241        0.4634849          1.038721
     tfidf_name_according tfidf_name_achieving tfidf_name_act
[1,]             9.703266            0.2824648      0.4634849
     tfidf_name_activity tfidf_name_after tfidf_name_afterlife
[1,]           -1.138795        -0.559748            -0.891733
     tfidf_name_aftermath tfidf_name_ages tfidf_name_agreements
[1,]          -0.02803979          -1.847             -1.573297
     tfidf_name_alaska tfidf_name_alex tfidf_name_all tfidf_name_alphabet
[1,]       -0.04722675      -0.4123504       1.732344          -0.7532651
     tfidf_name_am tfidf_name_amazing tfidf_name_america tfidf_name_american
[1,]   -0.01569643          -1.082116         -0.3258213             10.2308
     tfidf_name_americans tfidf_name_an tfidf_name_and tfidf_name_animal
[1,]           -0.4406231      -4.06854     -0.1338292        -0.8972068
     tfidf_name_animals tfidf_name_answers tfidf_name_apologizing
[1,]         -0.5771467           1.038721              0.2824648
     tfidf_name_are tfidf_name_art tfidf_name_as tfidf_name_assassination
[1,]     -0.4564853     -0.8431691    -0.4958947               -0.5068097
     tfidf_name_assault tfidf_name_association tfidf_name_astounding
[1,]          -1.443708               11.55352            -0.5465791
     tfidf_name_astrophysics tfidf_name_at tfidf_name_atomic tfidf_name_awesome
[1,]               -1.292433    -0.2544069         0.1746368         -0.4564853
     tfidf_name_azkaban tfidf_name_back tfidf_name_bad tfidf_name_badass
[1,]           3.269275       -2.528743      0.1746368        -0.4564853
     tfidf_name_ball tfidf_name_ballad tfidf_name_basketball
[1,]      0.09937253          1.719476              9.703266
     tfidf_name_battling tfidf_name_be tfidf_name_bear tfidf_name_beautiful
[1,]         -0.07962897     0.1701213       -1.771094          -0.01569643
     tfidf_name_becoming tfidf_name_bed tfidf_name_belly tfidf_name_berlin
[1,]          -0.1864733      -1.458195        -1.227587        -0.4406231
     tfidf_name_big tfidf_name_bill tfidf_name_billionaires tfidf_name_blood
[1,]      -1.729548       -1.683616                1.389012        -1.131105
     tfidf_name_boat tfidf_name_body tfidf_name_book tfidf_name_books
[1,]      -0.4439841       -1.138795       0.8875906        -1.971245
     tfidf_name_boy's tfidf_name_boys tfidf_name_brave tfidf_name_break
[1,]       -0.5465791      -0.4406231      -0.01569643        0.1746368
     tfidf_name_bree tfidf_name_bringing tfidf_name_brink tfidf_name_brown

...
and 282 more lines.

Взглянем на остатки. Для этого пригодится уже знакомая функция augment() из пакета broom.

svm_res <- augment(svm_fit, new_data = books_test) |> 
  mutate(res = price - .pred) |> 
  select(price, .pred, res)

svm_res
library(gridExtra)

g1 <- svm_res |> 
  mutate(res = price - .pred) |> 
  ggplot(aes(res)) +
  geom_histogram(fill = "steelblue", color  = "white") +
  theme_minimal()

g2 <- svm_res |> 
  ggplot(aes(price, .pred)) +
  geom_jitter(color = "steelblue", alpha = 0.7) +
  geom_abline(linetype = 2, color = "grey80", linewidth = 2) +
  theme_minimal()

grid.arrange(g1, g2, nrow = 1)

Соберем метрики.

books_metrics <- metric_set(rmse, rsq, mae)
books_metrics(svm_res, truth = price,  estimate = .pred)

Также посмотрим, какие слова больше всего связаны с увеличением и с уменьшением цены.

svm_fit |> 
  tidy() |> 
  filter(term != "year") |> 
  filter(!str_detect(term, "genre")) |> 
  mutate(sign = case_when(estimate > 0 ~ "дороже",
                          .default = "дешевле"),
         estimate = abs(estimate), 
         term = str_remove_all(term, "tfidf_name_")) |> 
  group_by(sign) |> 
  top_n(20, estimate) |> 
  ungroup() |> 
  ggplot(aes(x = estimate, y = fct_reorder(term, estimate),
             fill = sign)) +
  geom_col(show.legend = FALSE) +
  scale_x_continuous(expand = c(0,0)) +
  facet_wrap(~sign, scales = "free") +
  labs(y = NULL, 
       title = "Связь слов с ценой книг") +
  theme_minimal()

Любопытно: судя по нашему датасету, конституция США раздается на Амазоне бесплатно.